package problem212_Word_Search_II;

import java.util.HashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Set;

/**
 * Finds which words of a dictionary can be spelled on a character board.
 *
 * <p>A word is spelled by a path of cells in which each step moves to one of the
 * four horizontally or vertically adjacent cells and no cell is used twice within
 * the same path. The dictionary is loaded into a {@link Trie} so that a path is
 * abandoned as soon as it stops being a prefix of any word.
 */
public class Solution {
	/** Row offsets of the four orthogonal neighbours; paired index-wise with {@link #dirY}. */
	private static final int[] dirX={0,1,0,-1};
	/** Column offsets of the four orthogonal neighbours; paired index-wise with {@link #dirX}. */
	private static final int[] dirY={1,0,-1,0};

	/**
	 * Returns every word from {@code words} that can be traced on {@code board}.
	 *
	 * <p>Each word appears at most once, in the order it is first discovered by a
	 * row-major scan of starting cells followed by a depth-first walk.
	 *
	 * @param board grid of characters; may be {@code null}, empty or have rows of
	 *              different lengths
	 * @param words candidate words; may be {@code null}; {@code null} and empty
	 *              entries are ignored
	 * @return the words found, or an empty list when the board or the dictionary is
	 *         missing or empty
	 */
	public List<String> findWords(char[][] board, String[] words) {
		List<String> result=new LinkedList<>();
		// Nothing can be found without a board or a dictionary; this also avoids
		// touching board[0] on a board with no rows.
		if(board==null||board.length==0||words==null){
			return result;
		}

		Trie trie=new Trie();
		for(String word:words){
			// An empty word can never match because every path holds at least one cell.
			if(word!=null&&!word.isEmpty()){
				trie.insert(word);
			}
		}

		// Size each row of the visited mask from the matching board row so that
		// boards with rows of different lengths are indexed safely.
		boolean[][] visited=new boolean[board.length][];
		for(int i=0;i<board.length;i++){
			visited[i]=new boolean[board[i].length];
		}

		// Companion set for constant-time duplicate checks; the list keeps the order.
		Set<String> seen=new HashSet<>();
		for(int i=0;i<board.length;i++){
			for(int j=0;j<board[i].length;j++){
				dfs(i,j,trie,"",result,seen,visited,board);
			}
		}
		return result;
    }

	/**
	 * Extends the current path with cell {@code (i, j)} and keeps exploring while the
	 * path is still a prefix of some dictionary word.
	 *
	 * @param i       row of the cell to append
	 * @param j       column of the cell to append
	 * @param trie    dictionary used for prefix and whole-word checks
	 * @param s       characters collected along the path so far
	 * @param result  words found so far, in discovery order
	 * @param seen    the same words as {@code result}, for fast duplicate detection
	 * @param visited cells already on the current path
	 * @param board   the character grid
	 */
	private void dfs(int i,int j,Trie trie,String s,List<String> result,Set<String> seen,boolean[][] visited,char[][] board){
		if(i<0||i>=board.length||j<0||j>=board[i].length||visited[i][j]){
			return;
		}
		char c=board[i][j];
		// The trie indexes children by (c - 'a'), so a cell outside 'a'..'z' can never
		// extend a stored word; stop here rather than hand it an out-of-range character.
		if(c<'a'||c>'z'){
			return;
		}
		s=s+c;
		if(!trie.startsWith(s)){
			return;
		}
		// Set.add reports whether the word is new, so each word is listed once.
		if(trie.search(s)&&seen.add(s)){
			result.add(s);
		}
		visited[i][j]=true;
		for(int k=0;k<dirX.length;k++){
			dfs(i+dirX[k],j+dirY[k],trie,s,result,seen,visited,board);
		}
		// Release the cell so other paths may pass through it.
		visited[i][j]=false;
	}
}

/**
 * A single node of a prefix tree: the character it was created for, one child
 * slot per letter, and a flag marking nodes at which an inserted word ends.
 */
class TrieNode {
	/** Character given at construction; stays at the {@code char} default for the no-arg constructor. */
	char val;
	/** Child slots, all {@code null} until a child is attached. */
	TrieNode[] sons;
	/** {@code true} when a complete word ends at this node. */
	boolean isEnd;

	/** Creates a node without a character, 26 empty child slots and no end-of-word mark. */
	public TrieNode() {
		this('\u0000');
	}

	/**
	 * Creates a node for the given character with 26 empty child slots and no
	 * end-of-word mark.
	 *
	 * @param c character this node stands for
	 */
	public TrieNode(char c) {
		this.val = c;
		this.sons = new TrieNode[26];
		this.isEnd = false;
	}
}

/**
 * Prefix tree over the lowercase letters {@code 'a'}..{@code 'z'}.
 *
 * <p>Children are stored in each node's {@code sons} array at index
 * {@code c - 'a'}, so only characters that map into that array can be stored.
 */
class Trie {
	private final TrieNode root;

	/** Creates an empty trie. */
	public Trie() {
		root = new TrieNode();
	}

	/**
	 * Adds a word to the trie.
	 *
	 * <p>The whole word is validated before any node is created, so a rejected word
	 * leaves the trie unchanged.
	 *
	 * @param word word to add
	 * @throws IllegalArgumentException if the word contains a character that has no
	 *                                  child slot (anything outside {@code 'a'}..{@code 'z'})
	 */
	public void insert(String word) {
		for (int i = 0; i < word.length(); i++) {
			char c = word.charAt(i);
			if (slotOf(root, c) < 0) {
				throw new IllegalArgumentException(
						"Unsupported character '" + c + "' at index " + i + " in word \"" + word + "\"");
			}
		}
		TrieNode node = root;
		for (int i = 0; i < word.length(); i++) {
			char c = word.charAt(i);
			int pos = slotOf(node, c);
			if (node.sons[pos] == null) {
				node.sons[pos] = new TrieNode(c);
			}
			node = node.sons[pos];
		}
		node.isEnd = true;
	}

	/**
	 * Tells whether exactly this word was inserted.
	 *
	 * @param word word to look up
	 * @return {@code true} if the word was inserted; {@code false} otherwise,
	 *         including when it contains a character the trie cannot store
	 */
	public boolean search(String word) {
		TrieNode node = walk(word);
		return node != null && node.isEnd;
	}

	/**
	 * Tells whether at least one inserted word begins with the given prefix.
	 *
	 * @param prefix prefix to look up; the empty prefix always matches
	 * @return {@code true} if every character of the prefix can be followed from the
	 *         root; {@code false} otherwise, including when the prefix contains a
	 *         character the trie cannot store
	 */
	public boolean startsWith(String prefix) {
		return walk(prefix) != null;
	}

	/** Follows {@code text} from the root; returns the node reached, or {@code null} if the path breaks. */
	private TrieNode walk(String text) {
		TrieNode node = root;
		for (int i = 0; i < text.length(); i++) {
			int pos = slotOf(node, text.charAt(i));
			if (pos < 0 || node.sons[pos] == null) {
				return null;
			}
			node = node.sons[pos];
		}
		return node;
	}

	/** Maps {@code c} to its child slot in {@code node}, or {@code -1} when it falls outside the array. */
	private static int slotOf(TrieNode node, char c) {
		int pos = c - 'a';
		return (pos >= 0 && pos < node.sons.length) ? pos : -1;
	}
}
